{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h2c6PyqQaNiA"
      },
      "source": [
        "# LIT Standalone Components\n",
        "\n",
        "This notebook shows how to use [Learning Interpretability Tool](https://pair-code.github.io/lit) (LIT) components with a binary classifier that labels statement sentiment (0 for negative, 1 for positive).\n",
        "\n",
        "All LIT backend components (models, datasets, metrics, generators, etc.) are standalone Python classes, and can easily be used from Colab or another Python context without starting a server. This can be handy for development, or if you want to re-use components in an offline workflow.\n",
        "\n",
        "Copyright 2021 Google LLC.\n",
        "SPDX-License-Identifier: Apache-2.0"
      ]
    },
    {
      "metadata": {
        "id": "mxeB9Hmtq8QP"
      },
      "cell_type": "code",
      "source": [
        "# Installing lit-nlp with pip also installs the prerequisite packages needed by the core LIT package.\n",
        "!pip install lit-nlp"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wADoVZ1Nq8QP"
      },
      "outputs": [],
      "source": [
        "import attr\n",
        "import pandas as pd\n",
        "\n",
        "from lit_nlp import notebook\n",
        "from lit_nlp.examples.glue import data\n",
        "from lit_nlp.examples.glue import models\n",
        "\n",
        "# Hide INFO and lower logs. Comment this out for debugging.\n",
        "from absl import logging\n",
        "logging.set_verbosity(logging.WARNING)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x38BqRdJFlyW"
      },
      "source": [
        "## Load data\n",
        "\n",
        "LIT's `Dataset` classes are just lists of records, plus spec information to describe each field."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AWhbAZg57RpB",
        "outputId": "f34dc6b4-9fe4-4d00-f987-a7660b96cd15"
      },
      "outputs": [],
      "source": [
        "sst_data = data.SST2Data('validation')\n",
        "sst_data.spec()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9GSfs1waBdLd",
        "outputId": "fcbb376d-dab5-4bae-b4f0-191643683c4a",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "sst_data.examples[:10]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7_hlS2eAFtxu"
      },
      "source": [
        "You can easily convert this to tabular form, too:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 414
        },
        "id": "fW6GyeJ8FrkB",
        "outputId": "2f397e2b-9eb6-417d-9c90-738c3bfb7c05"
      },
      "outputs": [],
      "source": [
        "pd.DataFrame(sst_data.examples)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YZeDs3bOFygS"
      },
      "source": [
        "## Load a model and run inference\n",
        "\n",
        "LIT's `Model` class defines a `predict()` function to perform inference. The `input_spec()` describes the expected inputs (it should be a subset of the dataset fields), and `output_spec()` describes the output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "30l9ZyTjxJjf",
        "outputId": "cda76ba1-a403-4f5a-e753-18dddef50107"
      },
      "outputs": [],
      "source": [
        "# Fetch the trained model weights and load the model to analyze.\n",
        "# mkdir -p keeps this cell re-runnable if the directory already exists.\n",
        "!wget https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz\n",
        "!mkdir -p sst2_tiny\n",
        "!tar -xvf sst2_tiny.tar.gz -C sst2_tiny\n",
        "\n",
        "sentiment_model = models.SST2Model('./sst2_tiny')\n",
        "sentiment_model.input_spec(), sentiment_model.output_spec()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KKcqVijmGPLH"
      },
      "source": [
        "The output spec has many fields, since this model returns embeddings, gradients, attention, and more. To avoid too much clutter, we can run inference on a few examples and view the predictions using Pandas:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 872
        },
        "id": "-KdWHCUHGP4o",
        "outputId": "2a16ae42-51a1-4285-ddae-21bb2b704fa9"
      },
      "outputs": [],
      "source": [
        "preds = list(sentiment_model.predict(sst_data.examples[:10]))\n",
        "pd.DataFrame(preds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kf_cufk9GXRp"
      },
      "source": [
        "If we just want the predicted probabilities for each class, we can look at the `probas` field:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 355
        },
        "id": "ZppG2HSrGXyc",
        "outputId": "560714fb-f43c-434e-b226-792e9d04b2de"
      },
      "outputs": [],
      "source": [
        "labels = sentiment_model.output_spec()['probas'].vocab\n",
        "pd.DataFrame([p['probas'] for p in preds], columns=pd.Index(labels, name='label'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "555KTfI8GcHN"
      },
      "source": [
        "## Salience methods\n",
        "\n",
        "We can use different interpretability components as well. Here's an example running LIME to get a salience map. The output has entries for each input field, though here that's just one field named \"sentence\":"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3iW42OVpGcme",
        "outputId": "4dd41374-18d0-4f0e-c813-d91a1667aad2"
      },
      "outputs": [],
      "source": [
        "from lit_nlp.components import lime_explainer\n",
        "lime = lime_explainer.LIME()\n",
        "\n",
        "lime_results = lime.run(sst_data.examples[:1], sentiment_model, sst_data)[0]\n",
        "lime_results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 324
        },
        "id": "ludO8l76Geyv",
        "outputId": "900408e3-c6c9-4f3b-d49a-4107b38ea827"
      },
      "outputs": [],
      "source": [
        "# Again, pretty-print output with Pandas. The SalienceMap object is just a dataclass defined using attr.s.\n",
        "pd.DataFrame(attr.asdict(lime_results['sentence']))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RosugrhIGifH",
        "outputId": "562df56b-5b07-407e-a314-5e0b23c1e006"
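      },
      "outputs": [],
      "source": [
        "# Integrated Gradients is another salience method. Its results are keyed by\n",
        "# model output field, here 'token_grad_sentence' (used in the next cell).\n",
        "from lit_nlp.components import gradient_maps\n",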
        "ig = gradient_maps.IntegratedGradients()\n",
        "\n",
        "ig_results = ig.run(sst_data.examples[:1], sentiment_model, sst_data)[0]\n",
        "ig_results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 355
        },
        "id": "v1ekBujJGk_I",
        "outputId": "20fb48a1-8251-4086-922a-c9a4a587515f"
      },
      "outputs": [],
      "source": [
        "# Again, pretty-print output with Pandas. The SalienceMap object is just a dataclass defined using attr.s.\n",
        "pd.DataFrame(attr.asdict(ig_results['token_grad_sentence']))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kRe6I5y1GlQv"
      },
      "source": [
        "## Metrics\n",
        "\n",
        "We can also compute metrics. The metrics components (via the `SimpleMetrics` API) automatically detect compatible fields using the `parent` attribute. In this case, that is our model's `probas` field, which is scored against `label` in the input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Av16xG_8Gn4y",
        "outputId": "37fa6860-e8a9-424b-dfb1-6920dee92de5",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "from lit_nlp.components import metrics\n",
        "classification_metrics = metrics.MulticlassMetrics()\n",
        "classification_metrics.run(sst_data.examples[:100], sentiment_model, sst_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8dKD5QKRGuOG"
      },
      "source": [
        "## Generators\n",
        "\n",
        "We can use counterfactual generators as well. Here's an example with a generator that simply scrambles words in a text segment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dHAv4C03GuvU",
        "outputId": "58f05e89-9c1b-4982-8535-989eae92e58e"
      },
      "outputs": [],
      "source": [
        "from lit_nlp.components import scrambler\n",
        "sc = scrambler.Scrambler()\n",
        "\n",
        "sc_in = sst_data.examples[:5]\n",
        "sc_out = sc.generate_all(sc_in, model=None, dataset=sst_data,\n",
        "                         config={'Fields to scramble': ['sentence']})\n",
        "# The output is a list of lists: one list of generated examples per original example.\n",
        "sc_out"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 202
        },
        "id": "5jCaQkpTGuzd",
        "outputId": "b89c7d7e-2379-443c-c36c-f910444aa1fc"
      },
      "outputs": [],
      "source": [
        "# Format as a flat table for display, including original sentences\n",
        "import itertools\n",
        "for ex_in, exs_out in zip(sc_in, sc_out):\n",
        "  for ex_out in exs_out:\n",
        "    ex_out['original_sentence'] = ex_in['sentence']\n",
        "pd.DataFrame(itertools.chain.from_iterable(sc_out), columns=['original_sentence', 'sentence', 'label'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XGZIbOy9IuBT"
      },
      "source": [
        "# Running the LIT UI\n",
        "\n",
        "You can also use these components in the LIT UI without leaving Colab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hLCOC1hmGu4x"
      },
      "outputs": [],
      "source": [
        "widget = notebook.LitWidget(models={'sentiment': sentiment_model},\n",
        "                            datasets={'sst2': sst_data}, port=8890)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z0DKzSjtIzWd"
      },
      "outputs": [],
      "source": [
        "widget.render(height=600)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wNxmBPr2q8QP"
      },
      "source": [
        "If you've found interesting examples using the LIT UI, you can access these in Python using `widget.ui_state`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kSGtVXRtq8QP"
      },
      "outputs": [],
      "source": [
        "widget.ui_state.primary  # the main selected datapoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FmyzL26Vq8QP"
      },
      "outputs": [],
      "source": [
        "widget.ui_state.selection  # the full selected set, if you have multiple points selected"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ktP0CWHgq8QP"
      },
      "outputs": [],
      "source": [
        "widget.ui_state.pinned  # the pinned datapoint, if you use the 📌 icon or comparison mode"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ntxyncu5q8QP"
      },
      "source": [
        "Note that these include some metadata; the bare example is in the `['data']` field for each record:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AwGRITwGq8QP"
      },
      "outputs": [],
      "source": [
        "widget.ui_state.primary['data']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5eUVaK2Eq8QP"
      },
      "outputs": [],
      "source": [
        "[ex['data'] for ex in widget.ui_state.selection]"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "LIT Components Example",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 1
}
